"""Langevin dynamics for concept learning with EBMs."""

import torch
import ipdb
st = ipdb.set_trace


# Axis limits of the workspace. langevin() clamps the first center
# coordinate against XMIN/XMAX and the second against YMIN/YMAX.
XMIN, XMAX = -1, 1
YMIN, YMAX = -0.5, 0.5
ZMIN, ZMAX = -1, 1
# Amount removed from every axis extent before halving (see LENS).
# NOTE(review): presumably a nominal object size -- confirm with the author.
SIZE = 0.15

# Per-axis value of (max - min - SIZE) / 2. LangevinDynamics._update uses
# the first two entries as a symmetric clamp bound for updated centers.
LENS = torch.tensor([
    upper - lower - SIZE
    for lower, upper in ((XMIN, XMAX), (YMIN, YMAX), (ZMIN, ZMAX))
]) / 2.0

# First CUDA device when one is available, CPU otherwise.
DEVICE = (
    torch.device('cuda:0') if torch.cuda.is_available()
    else torch.device('cpu')
)


class LangevinDynamics:
    """
    Langevin-style refinement of box centers under a frozen energy model.

    Each free argument is split on its last axis into its first two
    channels (optimized) and the remaining channels (kept fixed). Every
    step optionally perturbs the optimized channels with Gaussian noise,
    descends the summed model output using a clipped gradient, and clamps
    the result to ``[-LENS[:2], LENS[:2]]``.
    """

    def __init__(self, model, ld_lr=10, ld_steps=50):
        """
        Initialize instance.

        Args:
            model (nn.Module): a torch model with loaded weights; it is
                called as model(centers, sizes, fixed_args)
            ld_lr (float): 'learning rate' for langevin updates
            ld_steps (int): number of dynamics' steps
        """
        self.model = model
        self.device = next(model.parameters()).device
        self.ld_lr = ld_lr
        self.ld_steps = ld_steps

    def run(self, free_args, fixed_args=None, test=False):
        """
        Run a task-agnostic executor.

        The inputs are never modified: every free argument is detached
        and copied first. While the dynamics run, the model is in eval
        mode with all parameters frozen; on exit (including on error) the
        model's training flag and every parameter's requires_grad flag
        are restored to the values they had on entry, and parameter
        gradients are cleared.

        Args:
            free_args (list): tensors whose first two last-axis channels
                are updated; the remaining channels are passed through
            fixed_args (list or None): not updatable arguments, handed
                to the model unchanged
            test (bool): if True, no noise is injected
        Model call: model(centers, sizes, fixed_args), where centers and
        sizes are the two halves of free_args

        Returns:
            tuple: (final, trajectory). trajectory is a list with
                ld_steps + 1 entries (the initial state followed by the
                state after each step); each entry is a list of tensors
                shaped like free_args. final is trajectory[-1].
        """
        model = self.model
        params = list(model.parameters())
        # Remember the caller's state so it can be restored exactly
        was_training = model.training
        grad_flags = [param.requires_grad for param in params]

        # Independent leaf tensors: gradients on for centers only
        centers, sizes = [], []
        for arg in free_args:
            arg = arg.detach()
            centers.append(arg[..., :2].clone().requires_grad_(True))
            sizes.append(arg[..., 2:].clone())

        # Gradients off for EBM
        for param in params:
            param.requires_grad = False
        model.eval()
        try:
            collect_outs = [self._snapshot(centers, sizes)]
            for step in range(self.ld_steps):
                # Add noise
                if not test:
                    self._add_noise(centers, step)

                # Forward pass, then gradients wrt the (noisy) centers
                neg_out = model(centers, sizes, fixed_args)
                neg_out.sum().backward()

                self._update(centers, step)
                collect_outs.append(self._snapshot(centers, sizes))
        finally:
            # Hand the model back exactly as it was received
            for param, flag in zip(params, grad_flags):
                param.requires_grad = flag
            model.train(was_training)
            model.zero_grad()
        return collect_outs[-1], collect_outs

    @staticmethod
    def _snapshot(centers, sizes):
        """Return detached copies of the boxes as [centers | sizes]."""
        # torch.cat allocates new storage, so later in-place updates of
        # the centers cannot alter a stored snapshot
        return [
            torch.cat([center.detach(), size], -1)
            for center, size in zip(centers, sizes)
        ]

    def _add_noise(self, args, step=0):
        """
        Add zero-mean Gaussian noise in place to every tensor in args.

        The standard deviation is 0.1 * 0.8 ** step, i.e. it decays
        geometrically with the step index. Returns args.
        """
        std = 0.8 ** step * 0.1
        with torch.no_grad():
            for arg in args:
                # Single draw, in the dtype and on the device of arg
                arg.add_(torch.empty_like(arg).normal_(0, std))
        return args

    def _update(self, args, step=0):
        """
        Apply one clipped gradient-descent step in place to args.

        Each tensor's gradient is clipped element-wise to [-0.01, 0.01],
        scaled by ld_lr * 0.95 ** step and subtracted; the gradient is
        then zeroed and the tensor is clamped to
        [-LENS[:2], LENS[:2]]. Returns args.
        """
        step_lr = self.ld_lr * 0.95 ** step
        with torch.no_grad():
            for arg in args:
                arg.grad.clamp_(-0.01, 0.01)
                arg.add_(-step_lr * arg.grad)

                # Zero gradients (do not accumulate)
                arg.grad.detach_()
                arg.grad.zero_()

                # Clamp; one bound converted to arg's dtype and device
                bound = LENS[:2].to(arg)
                arg.copy_(torch.max(torch.min(arg, bound), -bound))
        return args


def run_model(models, centers, sizes, subj, obj, rel, move_all):
    """
    Sum the energies of a list of relations over a set of boxes.

    Relation k is described by subj[k], obj[k] and rel[k]. Two kinds are
    handled:

    * Shape concepts ('circle', 'square', 'triangle', 'line'): the model
      is called once with (centers of subj[k] with a leading axis of 1,
      None, None, an all-ones mask of shape (1, len(subj[k]))). obj[k]
      is read but not used.
    * Any other relation: the model is called for every
      (subject, object) pair with (subject center, subject size,
      object center, object size), each viewed as (1, 2), and the
      results are averaged over the pairs.

    Args:
        models (dict): maps a relation name to a callable taking one
            tuple argument
        centers (tensor): box centers, indexed by box id on the first
            axis
        sizes (tensor): box sizes, indexed like centers
        subj (list): per relation, the list of subject box ids
        obj (list): per relation, the list of object box ids
        rel (list): per relation, its name (a key of models)
        move_all (bool): if False, object centers and sizes are detached
            so that gradients only reach the subject boxes

    Returns:
        The accumulated energy; the int 0 if subj is empty.
    """
    shape_concepts = ('circle', 'square', 'triangle', 'line')
    energy = 0
    for k, s in enumerate(subj):
        o = obj[k]
        r = rel[k]
        if r in shape_concepts:
            # Mask lives on the same device as the boxes it describes
            mask = torch.ones(1, len(s), device=centers.device)
            energy = energy + models[r]((
                centers[torch.as_tensor(s)][None],
                None,
                None,
                mask
            ))
            continue

        n_pairs = len(s) * len(o)
        for s_ in s:
            subj_center = centers[s_].view(1, 2)
            subj_size = sizes[s_].view(1, 2)
            for o_ in o:
                obj_center = centers[o_].view(1, 2)
                obj_size = sizes[o_].view(1, 2)
                if not move_all:
                    # Objects act as fixed anchors: block their gradients
                    obj_center = obj_center.detach()
                    obj_size = obj_size.detach()
                energy = energy + models[r]((
                    subj_center, subj_size, obj_center, obj_size
                )) / n_pairs
    return energy


def langevin(models, boxes, subj, obj, rel, move_all=True, n_steps=50,
             step_size=1.0):
    """
    Move box centers by gradient descent on the relations' total energy.

    Every step evaluates run_model on the current centers, subtracts
    step_size times the gradient of the summed energy wrt the centers
    and clamps the result so that the first coordinate lies in
    [XMIN + size_0, XMAX - size_0] and the second in
    [YMIN + size_1, YMAX - size_1]. No noise is injected. The input
    array is not modified.

    Args:
        boxes (np.ndarray): boxes; the first two entries of the last
            axis are the centers, the rest are the sizes
            (NOTE(review): run_model views each size as (1, 2), so four
            channels are presumably expected -- confirm)
        models, subj, obj, rel, move_all: forwarded to run_model
        n_steps (int): number of descent steps
        step_size (float): multiplier applied to the gradient

    Returns:
        tuple: (final boxes as a float tensor on DEVICE,
            list with the boxes after each of the n_steps steps)
    """
    boxes = torch.from_numpy(boxes).float()
    centers = boxes[..., :2].to(DEVICE)
    sizes = boxes[..., 2:].to(DEVICE)

    # Sizes never change, so the clamp bounds are loop-invariant
    x_low, x_high = XMIN + sizes[..., 0], XMAX - sizes[..., 0]
    y_low, y_high = YMIN + sizes[..., 1], YMAX - sizes[..., 1]

    negs_samples = []
    for _ in range(n_steps):
        # Forward pass
        centers.requires_grad_(True)
        energy = run_model(models, centers, sizes, subj, obj, rel, move_all)

        # Backward pass (gradients wrt centers); the subtraction creates
        # a new tensor, so the storage of the input array is never
        # written to
        grad = torch.autograd.grad([energy.sum()], [centers])[0]
        centers = centers - step_size * grad

        # Detach/clamp/store
        centers = centers.detach()
        centers[..., 0] = torch.clamp(centers[..., 0], x_low, x_high)
        centers[..., 1] = torch.clamp(centers[..., 1], y_low, y_high)
        negs_samples.append(torch.cat((centers, sizes), -1))

    return torch.cat((centers, sizes), -1), negs_samples
